import graphframes
import pyspark
from pyspark.ml import feature
from pyspark.ml.recommendation import ALS
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from graphframes import *
from pandas import *
from graphframes import GraphFrame
from pyspark.sql.functions import col, lit
from py4j import *
import time


# TODO: data modeling, data analysis, etc.
def getData(url="C:/Users/Lenovo/Desktop/Working/A毕设-Graphx/result_demo_100.csv",
            userListUrl="C:/Users/Lenovo/Desktop/Working/A毕设-Graphx/userList_demo_100.csv"):
    """Load the interaction data and the user list as Spark DataFrames.

    The paths were previously hard-coded; they are now parameters whose
    defaults are the original demo files, so existing callers are unchanged
    while new callers can point at other CSV files.

    :param url: path to the interaction/edge CSV (read with a header row)
    :param userListUrl: path to the user-list CSV (read with a header row)
    :return: tuple ``(data, userList)`` of Spark DataFrames
    """
    spark = SparkSession.builder.master("local").appName("graphframe test").getOrCreate()
    # A checkpoint dir is required later by graphframes' connectedComponents.
    spark.sparkContext.setCheckpointDir("/checkPoint")

    data = spark.read.csv(url, header=True)
    userList = spark.read.csv(userListUrl, header=True)
    data.show(10, False)
    return data, userList


def graphFrame():
    """Assemble a GraphFrame from the raw CSV data.

    Vertices come from the user list (string names indexed to integer ids);
    edges come from the interaction data, with both endpoints mapped through
    the same StringIndexer so ``src`` and ``dst`` share the vertex id space.

    :return: the constructed GraphFrame
    """
    edgeData, users = getData()
    edgeData.cache()
    raw_edges = edgeData.withColumnRenamed("host", "name")

    # One indexer, fit on the user list, is reused for both endpoints so a
    # given name always maps to the same integer id.
    indexer = feature.StringIndexer() \
        .setInputCol("name") \
        .setOutputCol("id") \
        .fit(users)
    vertices = indexer.transform(users).withColumn("id", col("id").cast(IntegerType()))

    # First pass: index the host column (currently "name") and call it "src".
    raw_edges = indexer.transform(raw_edges) \
        .withColumn("id", col("id").cast(IntegerType())) \
        .withColumnRenamed("id", "src") \
        .drop("name") \
        .withColumnRenamed("guest", "name")
    # Second pass: index the guest column and call it "dst".
    raw_edges = indexer.transform(raw_edges) \
        .withColumn("id", col("id").cast(IntegerType())) \
        .withColumnRenamed("id", "dst") \
        .drop("name") \
        .drop("attitude") \
        .withColumn("relationship", lit("friend"))
    # NOTE: on the first run the edge list should be saved — this CSV is one
    # of the base files used by later processing steps; keep commented after.
    # pdf = raw_edges.toPandas()
    # pdf = pdf[["src", "dst", "relationship"]]
    # pdf.to_csv("userLink_demo_100.csv", index=False)

    print("正在组成graphFrame... ----------------------------------------------------------------------------------->\n\n")
    graphData = GraphFrame(vertices, raw_edges)
    # graphData.inDegrees.show()
    print("成功创建GraphFrame")

    # Inspect the GraphFrame.
    print("查看当前GraphFrame")
    graphData.vertices.show(5, False)
    graphData.edges.show(5, False)

    # Degree of every vertex.
    print("计算每个顶点的度")
    graphData.degrees.show(5, False)
    return graphData


def set_Als():
    """Train an explicit-feedback ALS model and return per-user top-5 picks.

    :return: DataFrame with columns ``(id, top1..top5)`` where each ``topN``
             holds the N-th recommended guestId for that user.
    """
    spark = SparkSession.builder.master("local").appName("ALs test").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    ratings = spark.read.csv("./ALS_score.csv", header=True)
    # ratings.show(10, True)

    # A single indexer fit on the host column is reused for the guest column
    # as well; setHandleInvalid("keep") avoids failures on guest names that
    # never appear as hosts (checked: output looks correct, no string->double
    # issues observed — an "unlikely" corner, but worth the guard).
    host_indexer = feature.StringIndexer() \
        .setInputCol("host") \
        .setOutputCol("hostIndex") \
        .setHandleInvalid("keep") \
        .fit(ratings)

    indexed = host_indexer.transform(ratings) \
        .withColumnRenamed("hostIndex", "hostId") \
        .withColumnRenamed("host", "hostName") \
        .withColumnRenamed("guest", "host")
    indexed = host_indexer.transform(indexed) \
        .withColumnRenamed("hostIndex", "guestId") \
        .withColumnRenamed("host", "guestName")
    indexed = indexed.withColumn("attitude", col("attitude").cast(DoubleType()))
    trainData, testData = indexed.randomSplit([0.7, 0.3])

    print("正在创建ALS推荐算法------------------------------------------------------------------------------------------>\n")
    start = time.time()
    als = ALS() \
        .setRegParam(0.01) \
        .setMaxIter(50) \
        .setItemCol("guestId") \
        .setRatingCol("attitude") \
        .setUserCol("hostId") \
        .setRank(12)
    model = als.fit(trainData)
    print("ALS时间花费为:{time}s".format(time=time.time() - start))

    # Drop rows where ALS produced no prediction (cold-start users/items).
    prediction = model.transform(testData).na.drop()
    # prediction.show(5, False)
    print("正在为用户推荐Top5\n")
    recs = model.recommendForAllUsers(5)

    # Flatten the recommendations array into top1..top5 struct columns, then
    # keep only the guestId field of each struct.
    top_cols = [col("recommendations").getItem(i).alias("top{}".format(i + 1))
                for i in range(5)]
    splitDF = recs.select(col("hostId"), *top_cols)
    for n in range(1, 6):
        name = "top{}".format(n)
        splitDF = splitDF.withColumn(name, col(name).getField("guestId"))
    return splitDF.withColumnRenamed("hostId", "id")


def als_Modify(graphData, nodes, user_pageRank):
    """Split users by PageRank and recommend via two different strategies.

    High-PageRank users (top ~20%) get ALS-based recommendations to expand
    their social circle; low-PageRank users get motif-based ("modify")
    friend-of-friend suggestions to stabilize their existing network.
    Both result sets are persisted to MySQL (best effort).

    :param graphData: GraphFrame of the social network
    :param nodes: pandas DataFrame of vertices sorted by pagerank descending
    :param user_pageRank: Spark DataFrame of vertices with a pagerank column
    :return: ``(alsDF, modifyDF, rank)`` — the two recommendation DataFrames
             and the pagerank cutoff value used to split the user base
    """
    print("查询搜索，可能认识的人--------------------------------------------------------》\n")
    # Motif: A->B->C but not A->C (friend-of-friend A doesn't know yet).
    results = graphData.find("(A)-[]->(B); (B)-[]->(C); !(A)-[]->(C)")
    # Filter out loops (A recommended to itself).
    results = results.filter("A.id != C.id")
    # Select recommendations of C for A.
    results = results.select(results.A.id.alias("id"), results.C.id.alias("Top"))
    # results.sort("id").show()

    # Apply the 80/20 rule to the pagerank scores:
    #   - low-pagerank users -> "modify" recommendations (consolidate ties)
    #   - high-pagerank users -> ALS recommendations (expand their circle)
    # `nodes` is sorted descending, so the value at index int(n * 0.8) is the
    # pagerank separating the top 20% from the rest.
    cutoff_idx = int(nodes.pagerank.count() * 0.8)
    rank = nodes.pagerank.values[cutoff_idx]
    als_ID = user_pageRank.filter(col("pagerank") > rank).withColumnRenamed("id", "srcId")
    modify_ID = user_pageRank.filter(col("pagerank") <= rank).withColumnRenamed("id", "srcId")
    print(str(modify_ID.count()) + "用户接受了modify推荐")
    modify_ID.show()
    print("‘边缘化’人物的Modify推荐模式的推荐用户如下------------------------------------------------------------------->\n")
    modifyDF = results.join(modify_ID, modify_ID.srcId == results.id, "inner").drop("id")
    modifyDF = modifyDF.select("srcId", "Top").join(user_pageRank, modifyDF.Top == user_pageRank.id, "inner").drop("id") \
        .withColumnRenamed("Top", "recommendID")
    modifyDF.show(5, False)

    print("常规用户的Als推荐模式的推荐用户如下-------------------------------------------------------------------------->\n")
    alsDF = set_Als()
    alsDF = alsDF.join(als_ID, als_ID.srcId == alsDF.id, "inner").drop("id").select("srcId", "top1", "top2", "top3",
                                                                                    "top4", "top5")
    alsDF.show(5, False)
    modifyDF = modifyDF.distinct()

    # SECURITY NOTE: credentials are hard-coded; move them to a config file
    # or environment variables before any real deployment.
    prop = {
        'user': 'root',
        'password': 'z9633352',
        'driver': 'com.mysql.jdbc.Driver'}
    url = 'jdbc:mysql://localhost:3306/python_db'
    try:
        print("正在保存在mysql数据库中----------------------------------------------------------------------------->\n")
        modifyDF.write.jdbc(url=url, table="modifyRecommend", mode="overwrite", properties=prop)
        alsDF.write.jdbc(url=url, table="alsRecommend", mode="overwrite", properties=prop)
        print("保存成功------------------------------------------------------------------------------------------>\n")
    except Exception as e:
        # Was a bare `except:` that swallowed everything (even SystemExit /
        # KeyboardInterrupt) and hid the cause; keep the best-effort save but
        # catch only Exception and report what actually went wrong.
        print("保存失败----------------------------------------------------------------------------------------->\n")
        print(e)
    return alsDF, modifyDF, rank


def testGraph():
    """Run the three main parts of the experiment (parts are independent).

        1. Base graph metrics:
            PageRank (importance of each user in the social network),
            LPA (community detection),
            triangle counting (triangles each user participates in),
            connected components (largest connected social subgraph).
        2. Analysis built on the base metrics:
            2.1 Link prediction (user recommendation) — requires the
                PageRank influence weights; runs entirely in this script.
            2.2 Visualization (ECharts) — requires the PageRank weights plus
                the ECharts.js component (JavaScript, e.g. edited in VS Code).
    :return: None (prints the network's properties)
    """
    graphData = graphFrame()
    # graphData.filterVertices("id == 3").filterEdges("relationship == 'friend'").edges.show()

    # PageRank: find the key vertices.
    print("计算关键顶点--------------------------------------------------------》\n")
    results = graphData.pageRank(resetProbability=0.15, maxIter=80)
    userPagerank = results.vertices.sort("pagerank", ascending=False)
    userPagerank.show(10, False)
    # BUG FIX: the original bound `....show()` (which returns None) to
    # linkPagerank and then called .show() on None -> AttributeError.
    # Keep the DataFrame and display it separately.
    linkPagerank = results.vertices.sort("pagerank", ascending=False)
    linkPagerank.show(10, False)
    # svdPlus = results.svdPlusPlus()
    # svdPlus.show()

    # Cross-mode recommendation system; needs the PageRank results above.
    results = results.vertices.sort("pagerank", ascending=False)
    nodes = results.toPandas()
    nodes = nodes[["id", "name", "age", "occupation", "pagerank"]]
    # Precondition for the recommender: save to CSV so Data_Processing.py
    # can post-process the PageRank data.
    nodes.to_csv("userList_demo_100_ID.csv", index=False)
    # Run the Data_Processing.py steps directly from this file.
    import Data_Processing as Processing
    Processing.setJson()  # visualization uses the 100-user demo
    Processing.setALScore()  # works on the demo or the full network

    # User 26 is the demo case; in production the recommendation lists are
    # already stored in MySQL by als_Modify (or looked up in a cache table).
    alsDF, modifyDF, rank = als_Modify(graphData, nodes, results)

    recommendId = 26
    df = nodes[nodes["id"] == recommendId]
    print("需要推荐的26号用户------------->")
    print(df)
    if df.pagerank.values[0] > rank:
        alsDF.filter(col("srcId") == recommendId).show()
    else:
        modifyDF.filter(col("srcId") == recommendId).sort("pagerank", ascending=False).limit(5).show()

    # LPA: detect communities in the network.
    print("检测网络存在的社区--------------------------------------------------------》\n")
    results = graphData.labelPropagation(maxIter=10)
    results.sort("label").show()
    results.select("id", "label").groupBy("label").count().show()

    # Triangle counting.
    print("计算网络中存在的三角形----------------------------------------------------->\n")
    results = graphData.triangleCount()
    results.show()

    # Connected components (needs the checkpoint dir set in getData).
    timeOld = time.time()
    results = graphData.connectedComponents(algorithm="graphx")
    print("graphx用时花费了{time}".format(time=time.time() - timeOld))
    results.groupBy("component").count().show()
